import torch



# Default location of the processed dataset inspected by this script.
DEFAULT_PATH = 'data/GOODPubmed/homophily/processed/covariate.pt'


def load_train_masks(path=DEFAULT_PATH):
    """Load a processed dataset file and return the ``'train_masks'`` entry of its first item.

    Args:
        path: Path to a file written with ``torch.save``. The loaded object
            must support ``[0]`` indexing, and that first item must support
            ``['train_masks']`` lookup.

    Returns:
        Whatever object is stored under ``'train_masks'`` in the first item.
    """
    # NOTE(review): the processed file presumably holds pickled non-tensor
    # objects (e.g. a graph data object) — TODO confirm. torch>=2.6 defaults to
    # weights_only=True, which refuses such objects, so opt out explicitly.
    # weights_only=False runs pickle: only use this on trusted, locally
    # generated files. map_location='cpu' lets CUDA-saved tensors load on
    # machines without a GPU.
    dataset = torch.load(path, map_location='cpu', weights_only=False)
    graph = dataset[0]
    return graph['train_masks']


def main(argv=None):
    """Print the training masks of the dataset file given on the command line.

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``. The first argument, if present, overrides
            ``DEFAULT_PATH``.
    """
    import sys

    args = sys.argv[1:] if argv is None else argv
    path = args[0] if args else DEFAULT_PATH
    print(load_train_masks(path))


if __name__ == '__main__':
    main()